
import argparse
import time
from typing import Dict, List
import torch
import numpy as np
import sys
import os

import sys
sys.path.append('..')
import mymodel
from utils.view_pt import select_weight_file
from quant_dorefa import activation_quantize_fn
from anypacking.quant_module import HWGQ, QuantConv2d, ImageInputQ

class ConvParam:
    """Plain attribute container describing one convolution layer.

    Instances start out empty; the functions in this file attach fields
    (for example ``k``, ``ich``, ``och``, ``wbit``, ``abit``, ``w``, ``inc``
    and ``bias``) as the corresponding values become known.
    """

def write_hls_config(model_param, path):
    """Write ``config.h`` containing one ``#define`` per exported layer field.

    For every layer in ``model_param`` a ``// conv_<n>`` section is emitted,
    followed by ``#define CONV_<n>_<MACRO> <value>`` for each attribute in the
    mapping below that the layer actually has.

    Reads the module-level ``parser`` and ``opt`` objects for the file banner.

    Args:
        model_param: sequence of layer parameter objects.
        path: prefix that ``'config.h'`` is appended to (no separator is
            inserted, so it is expected to end with one).
    """
    # Layer attribute -> macro suffix. Stride ('s') and padding ('p') are not exported.
    macro_suffix = {
        'k': 'K',
        'ich': 'IFM_CH',
        'irow': 'IFM_ROW',
        'icol': 'IFM_COL',
        'och': 'OFM_CH',
        'orow': 'OFM_ROW',
        'ocol': 'OFM_COL',
        'abit': 'IN_BIT',
        'wbit': 'W_BIT',
        'incbit': 'INC_BIT',
        'biasbit': 'BIAS_BIT',
        'simd': 'SIMD',
        'pe': 'PE',
        'lshift': 'L_SHIFT'
    }

    # Banner and include guard, spelled out line by line.
    header = (
        '/********************************************************************************\n'
        '* Filename: config.h\n'
        f'* Date: {time.ctime()}\n'
        f'* Description: This file is generated by {parser.prog}\n'
        f'*   ptfilename: {opt.weight} \n'
        '********************************************************************************/\n'
        '\n'
        '#ifndef _CONFIG_H_\n'
        '#define _CONFIG_H_\n'
        '\n'
    )

    sections = []
    for idx, layer in enumerate(model_param):
        defines = [
            f'#define CONV_{idx}_{suffix} {getattr(layer, attr)}\n'
            for attr, suffix in macro_suffix.items()
            if hasattr(layer, attr)  # e.g. conv_last has no incbit
        ]
        sections.append(f'// conv_{idx}\n' + ''.join(defines) + '\n')

    text = header + ''.join(sections) + '#endif'

    with open(path + 'config.h', 'w') as out:
        print(text, file=out)

def extract_model(in_shape):
    """Collect per-convolution parameters from the module-level ``model``.

    Iterates ``model.modules()`` as a small state machine that expects the
    pattern ``[QAct] -> [Pooling] -> Conv -> [BN] -> [Pooling]``:

    * an activation quantizer opens (or fills) the pending ``ConvParam`` with
      ``abit``/``astep`` and also becomes the ``obit``/``ostep`` of the
      previously finished convolution;
    * a ``Conv2d`` fills geometry, optional bias and quantized weights, then is
      appended to the result;
    * ``BatchNorm2d`` and ``MaxPool2d`` annotate the most recent convolution.

    Args:
        in_shape: ``[channels, rows, cols]`` of the network input. The
            sequence is copied, so the caller's object is left unmodified.

    Returns:
        List of ``ConvParam``, one per ``Conv2d`` in traversal order.

    Raises:
        NotImplementedError: for a ``Conv2d`` that is neither a ``QuantConv2d``
            nor an instance of a class named ``Conv2d_Q``.
        IndexError: if a ``BatchNorm2d``/``MaxPool2d`` is met before any
            convolution, or the model contains no convolution at all.
    """
    model_param: List[ConvParam] = []
    # Private copy: the shape is updated in place while walking the model and
    # must not leak back into the caller's list (also allows tuples).
    feature_map_shape = list(in_shape)
    conv_cnt = 0
    conv_cur = None
    for sub_module in model.modules():
        # expect [QAct] -> [Pooling] -> Conv -> [BN] -> [Pooling], state machine mode
        if isinstance(sub_module, (HWGQ, ImageInputQ, activation_quantize_fn)):
            print('  Detected ActQ Layer', end='')
            if conv_cur is None:
                conv_cur = ConvParam()
            if isinstance(sub_module, (HWGQ, ImageInputQ)):
                conv_cur.abit = sub_module.bit
                conv_cur.astep = sub_module.step
            else:
                # activation_quantize_fn: step derived from the bit width
                conv_cur.abit = sub_module.a_bit
                conv_cur.astep = 1/2**conv_cur.abit

            conv_cur.actq_class = type(sub_module).__name__
            print(f', abit {conv_cur.abit}, astep {conv_cur.astep}, class {conv_cur.actq_class}')

            if conv_cnt: # previous.obit = cur.abit
                model_param[conv_cnt-1].obit = conv_cur.abit
                model_param[conv_cnt-1].ostep = conv_cur.astep

        elif isinstance(sub_module, torch.nn.Conv2d):
            if conv_cur is None:
                conv_cur = ConvParam()
            conv_cur.n = conv_cnt
            print('Extract conv_%d'%conv_cnt, end='')

            # Only the first element of kernel/stride/padding is used, i.e. the
            # same value is applied to rows and columns.
            conv_cur.k = sub_module.kernel_size[0]
            conv_cur.s = sub_module.stride[0]
            conv_cur.p = sub_module.padding[0]
            conv_cur.ich = sub_module.in_channels
            conv_cur.och = sub_module.out_channels
            conv_cur.groups = sub_module.groups if hasattr(sub_module, 'groups') else 1
            conv_cur.irow = feature_map_shape[1]
            conv_cur.icol = feature_map_shape[2]

            feature_map_shape[0] = sub_module.out_channels
            feature_map_shape[1] = (feature_map_shape[1] + 2 * sub_module.padding[0] - sub_module.kernel_size[0]) // sub_module.stride[0] + 1
            feature_map_shape[2] = (feature_map_shape[2] + 2 * sub_module.padding[0] - sub_module.kernel_size[0]) // sub_module.stride[0] + 1
            conv_cur.orow = feature_map_shape[1]
            conv_cur.ocol = feature_map_shape[2]

            if sub_module.bias is not None:
                conv_cur.convbias = sub_module.bias.detach().numpy()
                print(', +bias', end='')

            if isinstance(sub_module, QuantConv2d): # New quant
                conv_cur.wbit = sub_module.bit
                conv_cur.w, conv_cur.wstep = sub_module.export_quant() # wstep is not QuantConv2d.step because of alpha

            elif type(sub_module).__name__ == 'Conv2d_Q': # Old dorefa quant
                conv_cur.wbit = sub_module.w_bit
                conv_cur.wstep = 1/2**(conv_cur.wbit-1)
                # tanh -> normalise by max |w| -> scale by 2**(wbit-1) -> round/clip to [-n, n-1]
                weight = np.tanh(sub_module.weight.detach().numpy())
                weight = weight / np.max(np.abs(weight))
                n = 2**(conv_cur.wbit-1)
                weight_q = weight * n
                weight_q = np.clip(np.round(weight_q),-n, n-1)
                weight_q = weight_q.astype(np.int32)
                conv_cur.w = weight_q
            else:
                raise NotImplementedError(sub_module)
            print(', ich {ich}, och {och}, irow {irow}, icol {icol}, ksp {k}{s}{p}, wbit {wbit}, wstep {wstep}, g {groups}'.format(**vars(conv_cur)))

            model_param.append(conv_cur)
            conv_cur = None
            conv_cnt += 1

        elif isinstance(sub_module, torch.nn.BatchNorm2d):
            print('  Detected BatchNorm2d')
            gamma = sub_module.weight
            beta = sub_module.bias
            mean = sub_module.running_mean
            var = sub_module.running_var
            eps = sub_module.eps

            # Fold BN into y = x * bn_w + bn_b for the preceding convolution.
            model_param[-1].bn_w = (gamma / (torch.sqrt(var + eps))).detach().numpy()
            model_param[-1].bn_b = (beta - (mean / (torch.sqrt(var + eps)) * gamma)).detach().numpy()

        elif isinstance(sub_module, torch.nn.MaxPool2d):
            # kernel_size may be stored as an int or as a (rows, cols) pair.
            pool_k = sub_module.kernel_size
            if isinstance(pool_k, (tuple, list)):
                pool_kh, pool_kw = pool_k
            else:
                pool_kh = pool_kw = pool_k
            # NOTE(review): the shape is divided by the kernel size, which matches
            # the real pooling output only when stride == kernel_size and there is
            # no padding/dilation -- confirm for new models.
            feature_map_shape[1] = feature_map_shape[1] // pool_kh
            feature_map_shape[2] = feature_map_shape[2] // pool_kw
            model_param[-1].max_pool = True

    if not hasattr(model_param[0], 'abit'): # train code rescaled [0,255] to [0,1) by /256 default
        model_param[0].abit = 8
    if not hasattr(model_param[0], 'astep'):
        model_param[0].astep = 1/256

    return model_param

def process_batchnorm(model_param):
    '''process_batchnorm(model_param)
    Fold the wstep, astep and ostep scales into batchnorm, then quantize.

    Every layer except the last gets ``lshift``, ``inc_raw``, ``bias_raw``,
    ``inc``, ``bias``, ``lshift_T``, ``incbit`` and ``biasbit``.
    The last layer only gets ``inc = None`` and ``div = 1/(wstep*astep)``.

    Method:
    Define MAC = Conv(w, a), out = MAC*BN_w + BN_b,
    wq = w/wstep, aq = a/astep, MACq = MAC/MACstep, outq = out/ostep.

    outq = (MAC*BN_w + BN_b) / ostep
         = MACq * (MACstep/ostep)*BN_w + BN_b/ostep
         = MACq *     inc_raw          + bias_raw
    next layer activation a' = ActQ(out), i.e. a'q = clip(round(outq))

    Quantization of inc_raw & bias_raw:
    outq_real = round((MACq*round(inc_raw*scale) + round(bias_raw*scale)) / scale)         ; where scale=2**T
              = (MACq*round(inc_raw*scale) + round(bias_raw*scale) + 0.5 * scale) // scale ; div floor
              = (MACq*        inc          +         bias          +  2**(T-1)  ) >> T     ; [!] the 2**(T-1) bias is done by hls code

    Params:
    T = (wbit-1)+abit+lshift  # This comes from dorefa quant, not optimal
    MBIT = wbit+abit+ceil(log2(sum_number))
    incbit = len(bit(inc)); biasbit = len(bit(bias))
    larger lshift is better, but MBIT+incbit<48
    '''
    lshift = 6

    def signed_width(values):
        # one sign bit + bits needed for the largest magnitude
        return 1 + int(np.abs(values).max()).bit_length()

    for layer in model_param[:-1]:
        print(f'Process bn_{layer.n}, shape {layer.bn_w.shape},', end = ' ')

        # Fold the quantization steps into the BN scale and offset.
        layer.lshift = lshift
        mac_step = layer.wstep * layer.astep
        out_step = layer.ostep
        scale_raw = layer.bn_w * mac_step / out_step
        offset_raw = layer.bn_b / out_step
        layer.inc_raw = scale_raw
        layer.bias_raw = offset_raw

        # Fixed-point conversion with T fractional bits.
        shift_total = lshift+layer.wbit+layer.abit-1
        layer.inc = np.round(scale_raw * 2**shift_total).astype(np.int64)
        layer.bias = np.round(offset_raw * 2**shift_total).astype(np.int64)
        layer.lshift_T = shift_total

        # Storage widths for the generated ap_int arrays.
        layer.incbit = signed_width(layer.inc)
        layer.biasbit = signed_width(layer.bias)
        print(f'incbit {layer.incbit}, biasbit {layer.biasbit}, lshift_T {layer.lshift_T}')

    # The final layer has no inc; only the de-quantization divisor is stored.
    tail = model_param[-1]
    tail.inc = None
    tail.div = 1/(tail.wstep * tail.astep)

def reorder_weight(model_param, layers_simd, layers_pe, layers_actp, layers_pep):
    '''reorder_weight(model_param, layers_simd, layers_pe, layers_actp, layers_pep)
    Rearrange weight, inc and bias arrays into the layout used by the hls code.

    The layers are paired positionally with the four parallelism sequences
    (iteration stops at the shortest one). Each layer receives ``simd``,
    ``pe``, ``actp`` and ``pep`` and has ``w`` (and ``inc``/``bias`` when
    present) replaced by the reordered array:

    * inc / bias : [actp, och/actp]
    * k == 1     : [pe, och/(pe*pep) * ich/simd, simd*pep]
    * otherwise  : [pe, k, och/pe * k*ich/simd, simd]
    '''
    per_layer = zip(model_param, layers_simd, layers_pe, layers_actp, layers_pep)
    for conv, simd, pe, actp, pep in per_layer:
        print(f'Reorder conv_{conv.n}, w {conv.w.shape}', end='')
        conv.simd, conv.pe, conv.actp, conv.pep = simd, pe, actp, pep

        def actp_major(vec):
            # [och] -> [och/actp, actp] -> [actp, och/actp]
            return vec.reshape(conv.och//conv.actp, conv.actp).T

        # batchnorm terms
        if conv.inc is not None:
            conv.inc = actp_major(conv.inc)
        if getattr(conv, 'bias', None) is not None:
            conv.bias = actp_major(conv.bias)

        # convolution weight, input layout [och, ich, kr, kc]
        kernel = conv.w
        g_ich = kernel.shape[1]
        if conv.k != 1:
            assert conv.och%conv.pe == 0, f"conv_{conv.n}, och {conv.och}, pe {conv.pe}"
            assert conv.k*g_ich%simd == 0, f"conv_{conv.n}, ich {g_ich}, k {conv.k}, simd {conv.simd}"

            by_col = kernel.transpose(0, 3, 2, 1)  # [och, kc, kr, ich]
            grouped = by_col.reshape(conv.och//conv.pe, conv.pe, conv.k, conv.k*g_ich//simd, simd)
            pe_first = grouped.transpose(1,2,0,3,4)  # [pe, k, och/pe, k*ich/simd, simd]
            w = pe_first.reshape(conv.pe, conv.k, -1, simd)  # hls format [pe, k, och/pe*k*ich/simd, simd]
        else:
            assert conv.och%(conv.pe * conv.pep) == 0, f"conv_{conv.n}, och {conv.och}, pe {conv.pe}, pep {conv.pep}"
            assert g_ich%simd == 0, f"conv_{conv.n}, ich {g_ich}, simd {conv.simd}"

            # [och / (pe * pep), pe, pep, ich / simd, simd]
            grouped = kernel.reshape(conv.och//(conv.pe * conv.pep), conv.pe, conv.pep, g_ich//simd, simd)
            # [pe, och / (pe * pep), ich / simd, simd, pep]
            pe_first = grouped.transpose(1,0,3,4,2)
            # hls format [pe, och/(pe * pep) * ich/simd, simd * pep]
            w = pe_first.reshape(conv.pe, -1, simd*conv.pep)

        print(' ->', w.shape)

        conv.w = w

def print_ndarray_recursion(arr, str_func=str, file=sys.stdout, stop=0):
    """Print ``arr`` as a nested C-style brace initializer.

    Recurses over the leading axis until the remaining number of dimensions
    equals ``stop`` (or the value is not iterable); that remainder is rendered
    with ``str_func``. Siblings are separated by ``,`` and, except at the
    innermost brace level, a newline. No trailing newline is written.

    Args:
        arr: array (anything with ``.shape`` and iteration) or a scalar.
        str_func: callable turning a leaf into text.
        file: text stream to write to.
        stop: number of trailing dimensions passed to ``str_func`` as a unit.
    """
    if hasattr(arr, '__iter__') and len(arr.shape) != stop:
        # Innermost brace level keeps its items on one line.
        innermost = len(arr.shape) == stop + 1
        sep_end = '' if innermost else '\n'

        print('{', file=file, end='')
        for pos, child in enumerate(arr):
            print_ndarray_recursion(child, str_func, file, stop)
            if pos != len(arr) - 1:
                print(',', file=file, end=sep_end)
        print(sep_end + '}', file=file, end='')
    else:
        # Leaf: hand the remaining value to the formatter as a whole.
        print(str_func(arr), file=file, end='')

def write_hls_weights(model_param, path):
    '''write_hls_weights(model_param, path)
    Write hls weights+inc+bias array code according to numpy shape.

    Creates ``weights.hpp`` with, per layer, a ``conv_<n>_w`` array whose last
    axis is packed into one hex string per element, plus ``conv_<n>_inc`` and
    ``conv_<n>_bias`` arrays when the layer has them.

    Reads the module-level ``parser`` and ``opt`` objects for the file banner
    and uses ``print_ndarray_recursion`` for the brace initializers.

    Args:
        model_param: sequence of layer parameter objects, already reordered.
        path: prefix that ``'weights.hpp'`` is appended to (no separator is
            inserted, so it is expected to end with one).

    Raises:
        AssertionError: if a weight does not fit into ``wbit`` signed bits.
    '''
    def hex_str(x):
        # Quoted hex literal, e.g. "0x1f"
        return '"' + hex(x) + '"'

    # The context manager guarantees the header file is closed even when
    # packing fails half-way (e.g. on an out-of-range weight).
    with open(path + 'weights.hpp', 'w') as f:
        header = (
            '/********************************************************************************\n'
            '* Filename: weights.hpp\n'
            f'* Date: {time.ctime()}\n'
            f'* Description: This file is generated by {parser.prog}\n'
            f'*   ptfilename: {opt.weight} \n'
            '********************************************************************************/\n'
            '\n'
            '#ifndef _WEIGHTS_HPP_\n'
            '#define _WEIGHTS_HPP_\n'
            '#include <ap_int.h>\n'
        )
        print(header, file=f)

        for conv in model_param:
            n = conv.n

            def pack1d_str(arr): # x: 1d-array
                x = 0
                for v in arr[::-1]: # [!] reverse simd pack, it is related to hls implementation
                    v = int(v) # use python bignumber, not np.int
                    assert -1<<conv.wbit-1 <= v < 1<<conv.wbit-1, f'got v={v} while wbit={conv.wbit}'
                    # append the two's-complement wbit-wide field of v
                    x=(x<<conv.wbit) + (v&(2**conv.wbit-1))
                return hex_str(x)

            if conv.k == 1:
                print(f"Write conv_{n} weight, pe {conv.pe}, pep{conv.pep}, simd {conv.simd}, actp {conv.actp}, wbit {conv.wbit}")
                print(f"// layer: {n}, PE: {conv.pe}, PEP: {conv.pep}, SIMD: {conv.simd}, ACTP: {conv.actp}, wbit: {conv.wbit}", file=f)

                # print conv weight, merge [SIMD * PEP] values into one ap_uint
                print(f"const ap_uint<{conv.wbit * conv.pep * conv.simd}> conv_{n}_w[{conv.pe}][{conv.w.shape[1]}]=", file=f)
            else:
                print(f"Write conv_{n} weight, pe {conv.pe}, simd {conv.simd}, actp {conv.actp}, wbit {conv.wbit}")
                print(f"// layer: {n}, PE: {conv.pe}, SIMD: {conv.simd}, ACTP: {conv.actp}, wbit: {conv.wbit}", file=f)

                # print conv weight, merge [SIMD] values into one ap_uint
                print(f"const ap_uint<{conv.wbit * conv.simd}> conv_{n}_w[{conv.pe}][{conv.k}][{conv.w.shape[2]}]=", file=f)

            # stop=1: each innermost 1-d slice is packed into a single literal
            print_ndarray_recursion(conv.w, pack1d_str, f, stop=1)
            print(';', file=f)

            # print inc, bias
            if conv.inc is not None:
                print(f"const ap_int<{conv.incbit}> conv_{n}_inc[{conv.actp}][{conv.och//conv.actp}]=", file=f)
                print_ndarray_recursion(conv.inc, hex_str, f)
                print(';', file=f)
            if hasattr(conv, 'bias') and conv.bias is not None:
                print(f"const ap_int<{conv.biasbit}> conv_{n}_bias[{conv.actp}][{conv.och//conv.actp}]=", file=f)
                print_ndarray_recursion(conv.bias, hex_str, f)
                print(';', file=f)

        print('#endif', file=f)

def adjust_weight(model_param):
    """Raise the most negative weight code for selected (wbit, abit) pairs.

    For layers whose ``(wbit, abit)`` is listed below, every weight smaller
    than ``-2**(wbit-1)+1`` is replaced by that value; other layers are left
    untouched. ``conv.w`` is rebound to the clamped array.
    """
    # These packing can't quantize to -2**(wbit-1)
    restricted_pairs = ((4,2),(5,3),(5,4),(5,5),(5,6),(5,7),(5,8),(7,2),(7,3))
    for layer in model_param:
        if (layer.wbit, layer.abit) not in restricted_pairs:
            continue
        print(f'Adjust conv_{layer.n} wbit={layer.wbit}')
        floor_code = -2**(layer.wbit-1)+1
        layer.w = np.maximum(layer.w, floor_code)

if __name__=='__main__':
    # Command line: weight file, model class and parallelism table.
    # `parser`, `opt` and `model` are module-level names read by the functions above.
    parser = argparse.ArgumentParser()
    parser.add_argument('-w', '--weight', default=None, help='.pt file name in ./weights/')
    parser.add_argument('-m', '--model', default='SkyNet_FixQ', help = 'model class name in mymodel.py')
    parser.add_argument('-c', '--config-simd-pe', default='config_simd_pe_skynet', help = '.txt file in ./hls/')
    opt = parser.parse_args()
    if opt.weight is None:
        # No weight given: fall back to the selection helper.
        opt.weight = select_weight_file()

    # Per-layer table (first row skipped); columns are simd, pe, actp, pep.
    simd_pe = np.loadtxt('hls/'+opt.config_simd_pe+'.txt', dtype=int, skiprows=1)
    dir_output = 'hls/' + opt.weight + '/'
    if not os.path.exists(dir_output):
        os.makedirs(dir_output)

    # Load the checkpoint, build the model and restore its state_dict.
    ptfile:Dict = torch.load('weights/' + opt.weight + '.pt', map_location='cpu')
    model_cls = getattr(mymodel, opt.model)
    model = model_cls(**ptfile.setdefault('model_params', {}))
    model.load_state_dict(ptfile['model'])

    # Process
    model_param = extract_model([1, 160, 320])
    # adjust_weight(model_param)
    process_batchnorm(model_param) # get bn param before write hls config
    torch.save(model_param, dir_output + 'model_param.pkl')

    # get pe, simd, actp, pep param before write hls config
    reorder_weight(
        model_param,
        simd_pe[:,0],
        simd_pe[:,1],
        simd_pe[:,2],
        simd_pe[:,3],
    )
    write_hls_config(model_param, dir_output)
    write_hls_weights(model_param, dir_output)
